{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "25419d99-c8f7-4843-be39-7f21c7dd5266",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 深度学习计算\n",
    "\n",
    "目录\n",
    "\n",
    "- 数值稳定性和模型初始化\n",
    "- 层和块\n",
    "- 参数管理\n",
    "- 自定义层\n",
    "- 读写文件\n",
    "- GPU"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31aec33a-26b5-434f-8227-bcbea1b2fe10",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "slideshow": {
     "slide_type": "-"
    },
    "tags": []
   },
   "source": [
    "## 数值稳定性和模型初始化\n",
    "\n",
    "梯度消失：当输入的绝对值较大时，sigmoid 函数的梯度趋近于 0（梯度最大值仅为 0.25，在输入为 0 时取得），多层连乘后梯度可能消失。下面绘制 sigmoid 函数及其梯度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7f3e77b-9245-4c83-bad0-24eca173455e",
   "metadata": {
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (https://matplotlib.org/) -->\n",
       "<svg height=\"166.978125pt\" version=\"1.1\" viewBox=\"0 0 288.403125 166.978125\" width=\"288.403125pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2022-03-08T18:45:41.490781</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.3.3, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 166.978125 \n",
       "L 288.403125 166.978125 \n",
       "L 288.403125 0 \n",
       "L 0 0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 30.103125 143.1 \n",
       "L 281.203125 143.1 \n",
       "L 281.203125 7.2 \n",
       "L 30.103125 7.2 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 48.695149 143.1 \n",
       "L 48.695149 7.2 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"m3261089625\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"48.695149\" xlink:href=\"#m3261089625\" y=\"143.1\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- −7.5 -->\n",
       "      <g transform=\"translate(36.553743 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path d=\"M 10.59375 35.5 \n",
       "L 73.1875 35.5 \n",
       "L 73.1875 27.203125 \n",
       "L 10.59375 27.203125 \n",
       "z\n",
       "\" id=\"DejaVuSans-8722\"/>\n",
       "        <path d=\"M 8.203125 72.90625 \n",
       "L 55.078125 72.90625 \n",
       "L 55.078125 68.703125 \n",
       "L 28.609375 0 \n",
       "L 18.3125 0 \n",
       "L 43.21875 64.59375 \n",
       "L 8.203125 64.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-55\"/>\n",
       "        <path d=\"M 10.6875 12.40625 \n",
       "L 21 12.40625 \n",
       "L 21 0 \n",
       "L 10.6875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-46\"/>\n",
       "        <path d=\"M 10.796875 72.90625 \n",
       "L 49.515625 72.90625 \n",
       "L 49.515625 64.59375 \n",
       "L 19.828125 64.59375 \n",
       "L 19.828125 46.734375 \n",
       "Q 21.96875 47.46875 24.109375 47.828125 \n",
       "Q 26.265625 48.1875 28.421875 48.1875 \n",
       "Q 40.625 48.1875 47.75 41.5 \n",
       "Q 54.890625 34.8125 54.890625 23.390625 \n",
       "Q 54.890625 11.625 47.5625 5.09375 \n",
       "Q 40.234375 -1.421875 26.90625 -1.421875 \n",
       "Q 22.3125 -1.421875 17.546875 -0.640625 \n",
       "Q 12.796875 0.140625 7.71875 1.703125 \n",
       "L 7.71875 11.625 \n",
       "Q 12.109375 9.234375 16.796875 8.0625 \n",
       "Q 21.484375 6.890625 26.703125 6.890625 \n",
       "Q 35.15625 6.890625 40.078125 11.328125 \n",
       "Q 45.015625 15.765625 45.015625 23.390625 \n",
       "Q 45.015625 31 40.078125 35.4375 \n",
       "Q 35.15625 39.890625 26.703125 39.890625 \n",
       "Q 22.75 39.890625 18.8125 39.015625 \n",
       "Q 14.890625 38.140625 10.796875 36.28125 \n",
       "z\n",
       "\" id=\"DejaVuSans-53\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-55\"/>\n",
       "       <use x=\"147.412109\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"179.199219\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 84.587088 143.1 \n",
       "L 84.587088 7.2 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"84.587088\" xlink:href=\"#m3261089625\" y=\"143.1\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- −5.0 -->\n",
       "      <g transform=\"translate(72.445682 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-48\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "       <use x=\"147.412109\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"179.199219\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 120.479027 143.1 \n",
       "L 120.479027 7.2 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"120.479027\" xlink:href=\"#m3261089625\" y=\"143.1\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- −2.5 -->\n",
       "      <g transform=\"translate(108.337621 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-50\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"147.412109\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"179.199219\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 156.370967 143.1 \n",
       "L 156.370967 7.2 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.370967\" xlink:href=\"#m3261089625\" y=\"143.1\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 0.0 -->\n",
       "      <g transform=\"translate(148.419404 157.698438)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 192.262906 143.1 \n",
       "L 192.262906 7.2 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"192.262906\" xlink:href=\"#m3261089625\" y=\"143.1\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 2.5 -->\n",
       "      <g transform=\"translate(184.311343 157.698438)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_6\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 228.154845 143.1 \n",
       "L 228.154845 7.2 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"228.154845\" xlink:href=\"#m3261089625\" y=\"143.1\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_6\">\n",
       "      <!-- 5.0 -->\n",
       "      <g transform=\"translate(220.203282 157.698438)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-53\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_7\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 264.046784 143.1 \n",
       "L 264.046784 7.2 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"264.046784\" xlink:href=\"#m3261089625\" y=\"143.1\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 7.5 -->\n",
       "      <g transform=\"translate(256.095221 157.698438)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-55\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 30.103125 136.964174 \n",
       "L 281.203125 136.964174 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"mb28606d39a\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#mb28606d39a\" y=\"136.964174\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.0 -->\n",
       "      <g transform=\"translate(7.2 140.763392)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_17\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 30.103125 112.237629 \n",
       "L 281.203125 112.237629 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_18\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#mb28606d39a\" y=\"112.237629\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.2 -->\n",
       "      <g transform=\"translate(7.2 116.036848)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_19\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 30.103125 87.511085 \n",
       "L 281.203125 87.511085 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_20\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#mb28606d39a\" y=\"87.511085\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_10\">\n",
       "      <!-- 0.4 -->\n",
       "      <g transform=\"translate(7.2 91.310304)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-52\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-52\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_21\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 30.103125 62.784541 \n",
       "L 281.203125 62.784541 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_22\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#mb28606d39a\" y=\"62.784541\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_11\">\n",
       "      <!-- 0.6 -->\n",
       "      <g transform=\"translate(7.2 66.583759)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path d=\"M 33.015625 40.375 \n",
       "Q 26.375 40.375 22.484375 35.828125 \n",
       "Q 18.609375 31.296875 18.609375 23.390625 \n",
       "Q 18.609375 15.53125 22.484375 10.953125 \n",
       "Q 26.375 6.390625 33.015625 6.390625 \n",
       "Q 39.65625 6.390625 43.53125 10.953125 \n",
       "Q 47.40625 15.53125 47.40625 23.390625 \n",
       "Q 47.40625 31.296875 43.53125 35.828125 \n",
       "Q 39.65625 40.375 33.015625 40.375 \n",
       "z\n",
       "M 52.59375 71.296875 \n",
       "L 52.59375 62.3125 \n",
       "Q 48.875 64.0625 45.09375 64.984375 \n",
       "Q 41.3125 65.921875 37.59375 65.921875 \n",
       "Q 27.828125 65.921875 22.671875 59.328125 \n",
       "Q 17.53125 52.734375 16.796875 39.40625 \n",
       "Q 19.671875 43.65625 24.015625 45.921875 \n",
       "Q 28.375 48.1875 33.59375 48.1875 \n",
       "Q 44.578125 48.1875 50.953125 41.515625 \n",
       "Q 57.328125 34.859375 57.328125 23.390625 \n",
       "Q 57.328125 12.15625 50.6875 5.359375 \n",
       "Q 44.046875 -1.421875 33.015625 -1.421875 \n",
       "Q 20.359375 -1.421875 13.671875 8.265625 \n",
       "Q 6.984375 17.96875 6.984375 36.375 \n",
       "Q 6.984375 53.65625 15.1875 63.9375 \n",
       "Q 23.390625 74.21875 37.203125 74.21875 \n",
       "Q 40.921875 74.21875 44.703125 73.484375 \n",
       "Q 48.484375 72.75 52.59375 71.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-54\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-54\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_23\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 30.103125 38.057996 \n",
       "L 281.203125 38.057996 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_24\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#mb28606d39a\" y=\"38.057996\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_12\">\n",
       "      <!-- 0.8 -->\n",
       "      <g transform=\"translate(7.2 41.857215)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path d=\"M 31.78125 34.625 \n",
       "Q 24.75 34.625 20.71875 30.859375 \n",
       "Q 16.703125 27.09375 16.703125 20.515625 \n",
       "Q 16.703125 13.921875 20.71875 10.15625 \n",
       "Q 24.75 6.390625 31.78125 6.390625 \n",
       "Q 38.8125 6.390625 42.859375 10.171875 \n",
       "Q 46.921875 13.96875 46.921875 20.515625 \n",
       "Q 46.921875 27.09375 42.890625 30.859375 \n",
       "Q 38.875 34.625 31.78125 34.625 \n",
       "z\n",
       "M 21.921875 38.8125 \n",
       "Q 15.578125 40.375 12.03125 44.71875 \n",
       "Q 8.5 49.078125 8.5 55.328125 \n",
       "Q 8.5 64.0625 14.71875 69.140625 \n",
       "Q 20.953125 74.21875 31.78125 74.21875 \n",
       "Q 42.671875 74.21875 48.875 69.140625 \n",
       "Q 55.078125 64.0625 55.078125 55.328125 \n",
       "Q 55.078125 49.078125 51.53125 44.71875 \n",
       "Q 48 40.375 41.703125 38.8125 \n",
       "Q 48.828125 37.15625 52.796875 32.3125 \n",
       "Q 56.78125 27.484375 56.78125 20.515625 \n",
       "Q 56.78125 9.90625 50.3125 4.234375 \n",
       "Q 43.84375 -1.421875 31.78125 -1.421875 \n",
       "Q 19.734375 -1.421875 13.25 4.234375 \n",
       "Q 6.78125 9.90625 6.78125 20.515625 \n",
       "Q 6.78125 27.484375 10.78125 32.3125 \n",
       "Q 14.796875 37.15625 21.921875 38.8125 \n",
       "z\n",
       "M 18.3125 54.390625 \n",
       "Q 18.3125 48.734375 21.84375 45.5625 \n",
       "Q 25.390625 42.390625 31.78125 42.390625 \n",
       "Q 38.140625 42.390625 41.71875 45.5625 \n",
       "Q 45.3125 48.734375 45.3125 54.390625 \n",
       "Q 45.3125 60.0625 41.71875 63.234375 \n",
       "Q 38.140625 66.40625 31.78125 66.40625 \n",
       "Q 25.390625 66.40625 21.84375 63.234375 \n",
       "Q 18.3125 60.0625 18.3125 54.390625 \n",
       "z\n",
       "\" id=\"DejaVuSans-56\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-56\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_6\">\n",
       "     <g id=\"line2d_25\">\n",
       "      <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 30.103125 13.331452 \n",
       "L 281.203125 13.331452 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_26\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#mb28606d39a\" y=\"13.331452\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_13\">\n",
       "      <!-- 1.0 -->\n",
       "      <g transform=\"translate(7.2 17.130671)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-49\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_27\">\n",
       "    <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 41.516761 136.922713 \n",
       "L 65.923277 136.737562 \n",
       "L 77.4087 136.460971 \n",
       "L 86.022764 136.050337 \n",
       "L 93.201152 135.464702 \n",
       "L 98.943864 134.74049 \n",
       "L 103.250896 133.981285 \n",
       "L 107.557928 132.971398 \n",
       "L 110.429284 132.122009 \n",
       "L 113.30064 131.100784 \n",
       "L 116.171995 129.87703 \n",
       "L 119.043348 128.416404 \n",
       "L 121.914704 126.681307 \n",
       "L 124.786059 124.63175 \n",
       "L 127.657415 122.226792 \n",
       "L 130.528769 119.426742 \n",
       "L 133.400125 116.19615 \n",
       "L 136.271481 112.50763 \n",
       "L 139.142835 108.346262 \n",
       "L 142.014191 103.714212 \n",
       "L 144.885546 98.634876 \n",
       "L 147.756901 93.155696 \n",
       "L 152.063934 84.351342 \n",
       "L 164.985032 57.139929 \n",
       "L 167.856387 51.66075 \n",
       "L 170.727742 46.581409 \n",
       "L 173.599098 41.949358 \n",
       "L 176.470452 37.787991 \n",
       "L 179.341808 34.099477 \n",
       "L 182.213164 30.868893 \n",
       "L 185.084518 28.06884 \n",
       "L 187.955874 25.663873 \n",
       "L 190.827229 23.614316 \n",
       "L 193.698585 21.879221 \n",
       "L 196.569941 20.418595 \n",
       "L 199.441293 19.194841 \n",
       "L 202.312649 18.173618 \n",
       "L 206.619681 16.955405 \n",
       "L 210.926713 16.036702 \n",
       "L 215.233745 15.346977 \n",
       "L 220.976457 14.689795 \n",
       "L 228.154845 14.158904 \n",
       "L 236.768909 13.786942 \n",
       "L 248.254332 13.536533 \n",
       "L 266.918136 13.38742 \n",
       "L 269.789489 13.377273 \n",
       "L 269.789489 13.377273 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"line2d_28\">\n",
       "    <path clip-path=\"url(#p6bb7b3c6f2)\" d=\"M 41.516761 136.922727 \n",
       "L 65.923277 136.737978 \n",
       "L 77.4087 136.46302 \n",
       "L 86.022764 136.057092 \n",
       "L 93.201152 135.482889 \n",
       "L 98.943864 134.780485 \n",
       "L 103.250896 134.053253 \n",
       "L 107.557928 133.100346 \n",
       "L 111.86496 131.864446 \n",
       "L 114.736316 130.852567 \n",
       "L 117.607672 129.66889 \n",
       "L 120.479027 128.297061 \n",
       "L 123.35038 126.724966 \n",
       "L 126.221736 124.947729 \n",
       "L 129.093091 122.971391 \n",
       "L 133.400125 119.684796 \n",
       "L 144.885546 110.517935 \n",
       "L 147.756901 108.678954 \n",
       "L 150.628256 107.260083 \n",
       "L 152.063934 106.741126 \n",
       "L 153.499611 106.363024 \n",
       "L 154.935289 106.133133 \n",
       "L 156.370967 106.055993 \n",
       "L 157.806644 106.133135 \n",
       "L 159.242322 106.363026 \n",
       "L 160.677999 106.741126 \n",
       "L 162.113677 107.260083 \n",
       "L 164.985032 108.678954 \n",
       "L 167.856387 110.517935 \n",
       "L 170.727742 112.656506 \n",
       "L 176.470452 117.34553 \n",
       "L 180.777486 120.817036 \n",
       "L 185.084518 123.983526 \n",
       "L 187.955874 125.861917 \n",
       "L 190.827229 127.536562 \n",
       "L 193.698585 129.007383 \n",
       "L 196.569941 130.283295 \n",
       "L 199.441293 131.378861 \n",
       "L 203.748325 132.722848 \n",
       "L 208.055361 133.763369 \n",
       "L 212.362393 134.559937 \n",
       "L 218.105105 135.33127 \n",
       "L 223.847809 135.859872 \n",
       "L 231.026204 136.289615 \n",
       "L 241.075944 136.627342 \n",
       "L 255.43272 136.839828 \n",
       "L 269.789489 136.91837 \n",
       "L 269.789489 136.91837 \n",
       "\" style=\"fill:none;stroke:#bf00bf;stroke-dasharray:5.55,2.4;stroke-dashoffset:0;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 30.103125 143.1 \n",
       "L 30.103125 7.2 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 281.203125 143.1 \n",
       "L 281.203125 7.2 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 30.103125 143.1 \n",
       "L 281.203125 143.1 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 30.103125 7.2 \n",
       "L 281.203125 7.2 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"legend_1\">\n",
       "    <g id=\"patch_7\">\n",
       "     <path d=\"M 37.103125 44.55625 \n",
       "L 111.228125 44.55625 \n",
       "Q 113.228125 44.55625 113.228125 42.55625 \n",
       "L 113.228125 14.2 \n",
       "Q 113.228125 12.2 111.228125 12.2 \n",
       "L 37.103125 12.2 \n",
       "Q 35.103125 12.2 35.103125 14.2 \n",
       "L 35.103125 42.55625 \n",
       "Q 35.103125 44.55625 37.103125 44.55625 \n",
       "z\n",
       "\" style=\"fill:#ffffff;opacity:0.8;stroke:#cccccc;stroke-linejoin:miter;\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_29\">\n",
       "     <path d=\"M 39.103125 20.298438 \n",
       "L 59.103125 20.298438 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_30\"/>\n",
       "    <g id=\"text_14\">\n",
       "     <!-- sigmoid -->\n",
       "     <g transform=\"translate(67.103125 23.798438)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path d=\"M 44.28125 53.078125 \n",
       "L 44.28125 44.578125 \n",
       "Q 40.484375 46.53125 36.375 47.5 \n",
       "Q 32.28125 48.484375 27.875 48.484375 \n",
       "Q 21.1875 48.484375 17.84375 46.4375 \n",
       "Q 14.5 44.390625 14.5 40.28125 \n",
       "Q 14.5 37.15625 16.890625 35.375 \n",
       "Q 19.28125 33.59375 26.515625 31.984375 \n",
       "L 29.59375 31.296875 \n",
       "Q 39.15625 29.25 43.1875 25.515625 \n",
       "Q 47.21875 21.78125 47.21875 15.09375 \n",
       "Q 47.21875 7.46875 41.1875 3.015625 \n",
       "Q 35.15625 -1.421875 24.609375 -1.421875 \n",
       "Q 20.21875 -1.421875 15.453125 -0.5625 \n",
       "Q 10.6875 0.296875 5.421875 2 \n",
       "L 5.421875 11.28125 \n",
       "Q 10.40625 8.6875 15.234375 7.390625 \n",
       "Q 20.0625 6.109375 24.8125 6.109375 \n",
       "Q 31.15625 6.109375 34.5625 8.28125 \n",
       "Q 37.984375 10.453125 37.984375 14.40625 \n",
       "Q 37.984375 18.0625 35.515625 20.015625 \n",
       "Q 33.0625 21.96875 24.703125 23.78125 \n",
       "L 21.578125 24.515625 \n",
       "Q 13.234375 26.265625 9.515625 29.90625 \n",
       "Q 5.8125 33.546875 5.8125 39.890625 \n",
       "Q 5.8125 47.609375 11.28125 51.796875 \n",
       "Q 16.75 56 26.8125 56 \n",
       "Q 31.78125 56 36.171875 55.265625 \n",
       "Q 40.578125 54.546875 44.28125 53.078125 \n",
       "z\n",
       "\" id=\"DejaVuSans-115\"/>\n",
       "       <path d=\"M 9.421875 54.6875 \n",
       "L 18.40625 54.6875 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 64.59375 \n",
       "L 9.421875 64.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-105\"/>\n",
       "       <path d=\"M 45.40625 27.984375 \n",
       "Q 45.40625 37.75 41.375 43.109375 \n",
       "Q 37.359375 48.484375 30.078125 48.484375 \n",
       "Q 22.859375 48.484375 18.828125 43.109375 \n",
       "Q 14.796875 37.75 14.796875 27.984375 \n",
       "Q 14.796875 18.265625 18.828125 12.890625 \n",
       "Q 22.859375 7.515625 30.078125 7.515625 \n",
       "Q 37.359375 7.515625 41.375 12.890625 \n",
       "Q 45.40625 18.265625 45.40625 27.984375 \n",
       "z\n",
       "M 54.390625 6.78125 \n",
       "Q 54.390625 -7.171875 48.1875 -13.984375 \n",
       "Q 42 -20.796875 29.203125 -20.796875 \n",
       "Q 24.46875 -20.796875 20.265625 -20.09375 \n",
       "Q 16.0625 -19.390625 12.109375 -17.921875 \n",
       "L 12.109375 -9.1875 \n",
       "Q 16.0625 -11.328125 19.921875 -12.34375 \n",
       "Q 23.78125 -13.375 27.78125 -13.375 \n",
       "Q 36.625 -13.375 41.015625 -8.765625 \n",
       "Q 45.40625 -4.15625 45.40625 5.171875 \n",
       "L 45.40625 9.625 \n",
       "Q 42.625 4.78125 38.28125 2.390625 \n",
       "Q 33.9375 0 27.875 0 \n",
       "Q 17.828125 0 11.671875 7.65625 \n",
       "Q 5.515625 15.328125 5.515625 27.984375 \n",
       "Q 5.515625 40.671875 11.671875 48.328125 \n",
       "Q 17.828125 56 27.875 56 \n",
       "Q 33.9375 56 38.28125 53.609375 \n",
       "Q 42.625 51.21875 45.40625 46.390625 \n",
       "L 45.40625 54.6875 \n",
       "L 54.390625 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-103\"/>\n",
       "       <path d=\"M 52 44.1875 \n",
       "Q 55.375 50.25 60.0625 53.125 \n",
       "Q 64.75 56 71.09375 56 \n",
       "Q 79.640625 56 84.28125 50.015625 \n",
       "Q 88.921875 44.046875 88.921875 33.015625 \n",
       "L 88.921875 0 \n",
       "L 79.890625 0 \n",
       "L 79.890625 32.71875 \n",
       "Q 79.890625 40.578125 77.09375 44.375 \n",
       "Q 74.3125 48.1875 68.609375 48.1875 \n",
       "Q 61.625 48.1875 57.5625 43.546875 \n",
       "Q 53.515625 38.921875 53.515625 30.90625 \n",
       "L 53.515625 0 \n",
       "L 44.484375 0 \n",
       "L 44.484375 32.71875 \n",
       "Q 44.484375 40.625 41.703125 44.40625 \n",
       "Q 38.921875 48.1875 33.109375 48.1875 \n",
       "Q 26.21875 48.1875 22.15625 43.53125 \n",
       "Q 18.109375 38.875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.1875 51.21875 25.484375 53.609375 \n",
       "Q 29.78125 56 35.6875 56 \n",
       "Q 41.65625 56 45.828125 52.96875 \n",
       "Q 50 49.953125 52 44.1875 \n",
       "z\n",
       "\" id=\"DejaVuSans-109\"/>\n",
       "       <path d=\"M 30.609375 48.390625 \n",
       "Q 23.390625 48.390625 19.1875 42.75 \n",
       "Q 14.984375 37.109375 14.984375 27.296875 \n",
       "Q 14.984375 17.484375 19.15625 11.84375 \n",
       "Q 23.34375 6.203125 30.609375 6.203125 \n",
       "Q 37.796875 6.203125 41.984375 11.859375 \n",
       "Q 46.1875 17.53125 46.1875 27.296875 \n",
       "Q 46.1875 37.015625 41.984375 42.703125 \n",
       "Q 37.796875 48.390625 30.609375 48.390625 \n",
       "z\n",
       "M 30.609375 56 \n",
       "Q 42.328125 56 49.015625 48.375 \n",
       "Q 55.71875 40.765625 55.71875 27.296875 \n",
       "Q 55.71875 13.875 49.015625 6.21875 \n",
       "Q 42.328125 -1.421875 30.609375 -1.421875 \n",
       "Q 18.84375 -1.421875 12.171875 6.21875 \n",
       "Q 5.515625 13.875 5.515625 27.296875 \n",
       "Q 5.515625 40.765625 12.171875 48.375 \n",
       "Q 18.84375 56 30.609375 56 \n",
       "z\n",
       "\" id=\"DejaVuSans-111\"/>\n",
       "       <path d=\"M 45.40625 46.390625 \n",
       "L 45.40625 75.984375 \n",
       "L 54.390625 75.984375 \n",
       "L 54.390625 0 \n",
       "L 45.40625 0 \n",
       "L 45.40625 8.203125 \n",
       "Q 42.578125 3.328125 38.25 0.953125 \n",
       "Q 33.9375 -1.421875 27.875 -1.421875 \n",
       "Q 17.96875 -1.421875 11.734375 6.484375 \n",
       "Q 5.515625 14.40625 5.515625 27.296875 \n",
       "Q 5.515625 40.1875 11.734375 48.09375 \n",
       "Q 17.96875 56 27.875 56 \n",
       "Q 33.9375 56 38.25 53.625 \n",
       "Q 42.578125 51.265625 45.40625 46.390625 \n",
       "z\n",
       "M 14.796875 27.296875 \n",
       "Q 14.796875 17.390625 18.875 11.75 \n",
       "Q 22.953125 6.109375 30.078125 6.109375 \n",
       "Q 37.203125 6.109375 41.296875 11.75 \n",
       "Q 45.40625 17.390625 45.40625 27.296875 \n",
       "Q 45.40625 37.203125 41.296875 42.84375 \n",
       "Q 37.203125 48.484375 30.078125 48.484375 \n",
       "Q 22.953125 48.484375 18.875 42.84375 \n",
       "Q 14.796875 37.203125 14.796875 27.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-100\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-115\"/>\n",
       "      <use x=\"52.099609\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"79.882812\" xlink:href=\"#DejaVuSans-103\"/>\n",
       "      <use x=\"143.359375\" xlink:href=\"#DejaVuSans-109\"/>\n",
       "      <use x=\"240.771484\" xlink:href=\"#DejaVuSans-111\"/>\n",
       "      <use x=\"301.953125\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"329.736328\" xlink:href=\"#DejaVuSans-100\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"line2d_31\">\n",
       "     <path d=\"M 39.103125 34.976562 \n",
       "L 59.103125 34.976562 \n",
       "\" style=\"fill:none;stroke:#bf00bf;stroke-dasharray:5.55,2.4;stroke-dashoffset:0;stroke-width:1.5;\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_32\"/>\n",
       "    <g id=\"text_15\">\n",
       "     <!-- gradient -->\n",
       "     <g transform=\"translate(67.103125 38.476562)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path d=\"M 41.109375 46.296875 \n",
       "Q 39.59375 47.171875 37.8125 47.578125 \n",
       "Q 36.03125 48 33.890625 48 \n",
       "Q 26.265625 48 22.1875 43.046875 \n",
       "Q 18.109375 38.09375 18.109375 28.8125 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 20.953125 51.171875 25.484375 53.578125 \n",
       "Q 30.03125 56 36.53125 56 \n",
       "Q 37.453125 56 38.578125 55.875 \n",
       "Q 39.703125 55.765625 41.0625 55.515625 \n",
       "z\n",
       "\" id=\"DejaVuSans-114\"/>\n",
       "       <path d=\"M 34.28125 27.484375 \n",
       "Q 23.390625 27.484375 19.1875 25 \n",
       "Q 14.984375 22.515625 14.984375 16.5 \n",
       "Q 14.984375 11.71875 18.140625 8.90625 \n",
       "Q 21.296875 6.109375 26.703125 6.109375 \n",
       "Q 34.1875 6.109375 38.703125 11.40625 \n",
       "Q 43.21875 16.703125 43.21875 25.484375 \n",
       "L 43.21875 27.484375 \n",
       "z\n",
       "M 52.203125 31.203125 \n",
       "L 52.203125 0 \n",
       "L 43.21875 0 \n",
       "L 43.21875 8.296875 \n",
       "Q 40.140625 3.328125 35.546875 0.953125 \n",
       "Q 30.953125 -1.421875 24.3125 -1.421875 \n",
       "Q 15.921875 -1.421875 10.953125 3.296875 \n",
       "Q 6 8.015625 6 15.921875 \n",
       "Q 6 25.140625 12.171875 29.828125 \n",
       "Q 18.359375 34.515625 30.609375 34.515625 \n",
       "L 43.21875 34.515625 \n",
       "L 43.21875 35.40625 \n",
       "Q 43.21875 41.609375 39.140625 45 \n",
       "Q 35.0625 48.390625 27.6875 48.390625 \n",
       "Q 23 48.390625 18.546875 47.265625 \n",
       "Q 14.109375 46.140625 10.015625 43.890625 \n",
       "L 10.015625 52.203125 \n",
       "Q 14.9375 54.109375 19.578125 55.046875 \n",
       "Q 24.21875 56 28.609375 56 \n",
       "Q 40.484375 56 46.34375 49.84375 \n",
       "Q 52.203125 43.703125 52.203125 31.203125 \n",
       "z\n",
       "\" id=\"DejaVuSans-97\"/>\n",
       "       <path d=\"M 56.203125 29.59375 \n",
       "L 56.203125 25.203125 \n",
       "L 14.890625 25.203125 \n",
       "Q 15.484375 15.921875 20.484375 11.0625 \n",
       "Q 25.484375 6.203125 34.421875 6.203125 \n",
       "Q 39.59375 6.203125 44.453125 7.46875 \n",
       "Q 49.3125 8.734375 54.109375 11.28125 \n",
       "L 54.109375 2.78125 \n",
       "Q 49.265625 0.734375 44.1875 -0.34375 \n",
       "Q 39.109375 -1.421875 33.890625 -1.421875 \n",
       "Q 20.796875 -1.421875 13.15625 6.1875 \n",
       "Q 5.515625 13.8125 5.515625 26.8125 \n",
       "Q 5.515625 40.234375 12.765625 48.109375 \n",
       "Q 20.015625 56 32.328125 56 \n",
       "Q 43.359375 56 49.78125 48.890625 \n",
       "Q 56.203125 41.796875 56.203125 29.59375 \n",
       "z\n",
       "M 47.21875 32.234375 \n",
       "Q 47.125 39.59375 43.09375 43.984375 \n",
       "Q 39.0625 48.390625 32.421875 48.390625 \n",
       "Q 24.90625 48.390625 20.390625 44.140625 \n",
       "Q 15.875 39.890625 15.1875 32.171875 \n",
       "z\n",
       "\" id=\"DejaVuSans-101\"/>\n",
       "       <path d=\"M 54.890625 33.015625 \n",
       "L 54.890625 0 \n",
       "L 45.90625 0 \n",
       "L 45.90625 32.71875 \n",
       "Q 45.90625 40.484375 42.875 44.328125 \n",
       "Q 39.84375 48.1875 33.796875 48.1875 \n",
       "Q 26.515625 48.1875 22.3125 43.546875 \n",
       "Q 18.109375 38.921875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.34375 51.125 25.703125 53.5625 \n",
       "Q 30.078125 56 35.796875 56 \n",
       "Q 45.21875 56 50.046875 50.171875 \n",
       "Q 54.890625 44.34375 54.890625 33.015625 \n",
       "z\n",
       "\" id=\"DejaVuSans-110\"/>\n",
       "       <path d=\"M 18.3125 70.21875 \n",
       "L 18.3125 54.6875 \n",
       "L 36.8125 54.6875 \n",
       "L 36.8125 47.703125 \n",
       "L 18.3125 47.703125 \n",
       "L 18.3125 18.015625 \n",
       "Q 18.3125 11.328125 20.140625 9.421875 \n",
       "Q 21.96875 7.515625 27.59375 7.515625 \n",
       "L 36.8125 7.515625 \n",
       "L 36.8125 0 \n",
       "L 27.59375 0 \n",
       "Q 17.1875 0 13.234375 3.875 \n",
       "Q 9.28125 7.765625 9.28125 18.015625 \n",
       "L 9.28125 47.703125 \n",
       "L 2.6875 47.703125 \n",
       "L 2.6875 54.6875 \n",
       "L 9.28125 54.6875 \n",
       "L 9.28125 70.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-116\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-103\"/>\n",
       "      <use x=\"63.476562\" xlink:href=\"#DejaVuSans-114\"/>\n",
       "      <use x=\"104.589844\" xlink:href=\"#DejaVuSans-97\"/>\n",
       "      <use x=\"165.869141\" xlink:href=\"#DejaVuSans-100\"/>\n",
       "      <use x=\"229.345703\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"257.128906\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"318.652344\" xlink:href=\"#DejaVuSans-110\"/>\n",
       "      <use x=\"382.03125\" xlink:href=\"#DejaVuSans-116\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p6bb7b3c6f2\">\n",
       "   <rect height=\"135.9\" width=\"251.1\" x=\"30.103125\" y=\"7.2\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 324x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "from d2l import torch as d2l\n",
    "\n",
    "x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)\n",
    "y = torch.sigmoid(x)\n",
    "y.backward(torch.ones_like(x))\n",
    "\n",
    "d2l.plot(x.detach().numpy(), [y.detach().numpy(), x.grad.numpy()],\n",
    "         legend=['sigmoid', 'gradient'], figsize=(4.5, 2.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2cc49cd-aa08-4439-babb-5d7e2cefdb97",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "梯度爆炸"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b97a9ac6-93e5-4302-9637-beeff9caf0f8",
   "metadata": {
    "origin_pos": 6,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "一个矩阵 \n",
      " tensor([[ 0.5039, -0.5113,  0.2666,  0.9192],\n",
      "        [ 0.8290, -0.3719, -0.4758,  0.2095],\n",
      "        [-2.8356,  1.5128,  3.2530,  1.2554],\n",
      "        [-0.4369,  0.0571,  1.3445,  0.5465]])\n",
      "乘以100个矩阵后\n",
      " tensor([[-2.6163e+27, -1.3641e+27,  3.7875e+26,  1.5981e+27],\n",
      "        [ 9.2931e+25,  4.8452e+25, -1.3453e+25, -5.6765e+25],\n",
      "        [-8.5595e+27, -4.4627e+27,  1.2391e+27,  5.2284e+27],\n",
      "        [-3.1940e+27, -1.6653e+27,  4.6239e+26,  1.9510e+27]])\n"
     ]
    }
   ],
   "source": [
    "M = torch.normal(0, 1, size=(4,4))\n",
    "print('一个矩阵 \\n',M)\n",
    "for i in range(100):\n",
    "    M = torch.mm(M,torch.normal(0, 1, size=(4, 4)))\n",
    "\n",
    "print('乘以100个矩阵后\\n', M)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5bbc300-2e11-46b5-97a0-8a94fbb355b0",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "slideshow": {
     "slide_type": "-"
    },
    "tags": []
   },
   "source": [
    "## 层和块\n",
    "\n",
    "我们先回顾一下多层感知机"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6bb8a16d-5ecc-454a-858d-4e9e1e3e2332",
   "metadata": {
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0148, -0.0091, -0.1044, -0.0623,  0.1513,  0.0991,  0.1245, -0.1850,\n",
       "          0.0858,  0.1818],\n",
       "        [-0.0970,  0.0267,  0.0026, -0.0933, -0.0292,  0.1253,  0.2153, -0.0900,\n",
       "          0.0387,  0.0119]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n",
    "\n",
    "X = torch.rand(2, 20)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a7e60b4-5af9-447d-8461-47455b5427b5",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "`nn.Sequential`定义了一种特殊的`Module`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17417747-7867-4cc9-b6a6-6bea3c5fe307",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "自定义块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c010a3fd-b10d-4862-8326-114deea44c9f",
   "metadata": {
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    # 用模型参数声明层。这里，我们声明两个全连接的层\n",
    "    def __init__(self):\n",
    "        # 调用MLP的父类Module的构造函数来执行必要的初始化。\n",
    "        # 这样，在类实例化时也可以指定其他函数参数，例如模型参数params（稍后将介绍）\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(20, 256)  # 隐藏层\n",
    "        self.out = nn.Linear(256, 10)  # 输出层\n",
    "\n",
    "    # 定义模型的前向传播，即如何根据输入X返回所需的模型输出\n",
    "    def forward(self, X):\n",
    "        # 注意，这里我们使用ReLU的函数版本，其在nn.functional模块中定义。\n",
    "        return self.out(F.relu(self.hidden(X)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "514876b3-07c1-4238-b767-935fda3a3941",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "实例化多层感知机的层，然后在每次调用前向传播函数时调用这些层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "97818129-99da-40e8-8079-fcae5484d270",
   "metadata": {
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2511,  0.1825,  0.0906, -0.1014, -0.0818, -0.2257,  0.0343, -0.0502,\n",
       "         -0.1205,  0.1808],\n",
       "        [ 0.0538,  0.2270, -0.0294, -0.0254,  0.0274, -0.1483, -0.1254,  0.0424,\n",
       "         -0.1205,  0.1796]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MLP()\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2647ff40-26ca-4e20-83de-9c92aca16272",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "顺序块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d84d2ab-f4b0-473a-9cfd-688f8b9be647",
   "metadata": {
    "origin_pos": 26,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class MySequential(nn.Module):\n",
    "    def __init__(self, *args):\n",
    "        super().__init__()\n",
    "        for idx, module in enumerate(args):\n",
    "            # 这里，module是Module子类的一个实例。我们把它保存在'Module'类的成员\n",
    "            # 变量_modules中。_modules的类型是OrderedDict\n",
    "            self._modules[str(idx)] = module\n",
    "\n",
    "    def forward(self, X):\n",
    "        # OrderedDict保证了按照成员添加的顺序遍历它们\n",
    "        for block in self._modules.values():\n",
    "            X = block(X)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7e7ee471-4a2b-4df2-a313-c83018c7ef92",
   "metadata": {
    "origin_pos": 26,
    "slideshow": {
     "slide_type": "-"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0345,  0.0057,  0.0227, -0.1984, -0.2237, -0.0464, -0.0189, -0.2631,\n",
       "         -0.2752, -0.1415],\n",
       "        [-0.1851,  0.0243, -0.0286, -0.1372, -0.1288, -0.1205,  0.1659, -0.3036,\n",
       "         -0.2170, -0.0545]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b13e8e0b-975d-4d24-a517-a1b7ed95712a",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "在前向传播函数中执行代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8303f511-6ec1-4971-bc3a-ebc431b04b46",
   "metadata": {
    "origin_pos": 34,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0362, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class FixedHiddenMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # 不计算梯度的随机权重参数。因此其在训练期间保持不变\n",
    "        self.rand_weight = torch.rand((20, 20), requires_grad=False)\n",
    "        self.linear = nn.Linear(20, 20)\n",
    "\n",
    "    def forward(self, X):\n",
    "        X = self.linear(X)\n",
    "        # 使用创建的常量参数以及relu和mm函数\n",
    "        X = F.relu(torch.mm(X, self.rand_weight) + 1)\n",
    "        # 复用全连接层。这相当于两个全连接层共享参数\n",
    "        X = self.linear(X)\n",
    "        # 控制流\n",
    "        while X.abs().sum() > 1:\n",
    "            X /= 2\n",
    "        return X.sum()\n",
    "\n",
    "net = FixedHiddenMLP()\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9275f5f-ae9c-4ac3-a17e-e497ccca62d2",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "混合搭配各种组合块的方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2dab1c0a-9d6c-43ef-8fd2-e41bace79bf5",
   "metadata": {
    "origin_pos": 37,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-0.1137, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n",
    "                                 nn.Linear(64, 32), nn.ReLU())\n",
    "        self.linear = nn.Linear(32, 16)\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.linear(self.net(X))\n",
    "\n",
    "chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n",
    "chimera(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5491e9df-11d8-4dfa-a0da-07a0dc06d661",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    },
    "tags": []
   },
   "source": [
    "## 参数管理\n",
    "\n",
    "我们首先看一下具有单隐藏层的多层感知机"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "20a7bea9-9dbf-4c3b-a11a-e78fef1c42c7",
   "metadata": {
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1909],\n",
       "        [-0.2025]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n",
    "X = torch.rand(size=(2, 4))\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d1b0906-6939-4bc0-a93e-c7ea0870f955",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### 参数访问"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a91089e7-a386-4915-b2ab-70503785ca50",
   "metadata": {
    "origin_pos": 6,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('weight', tensor([[-0.2358, -0.2256, -0.1930, -0.0475, -0.0732, -0.3483,  0.0520,  0.1466]])), ('bias', tensor([-0.0579]))])\n"
     ]
    }
   ],
   "source": [
    "print(net[2].state_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d80d652-2846-4cb0-acf9-adec725b41bf",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### 目标参数\n",
    "\n",
    "注意，每个参数都表示为参数类（`nn.Parameter`）的一个实例。要对参数执行任何操作，首先我们需要访问底层的数值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "72fb1f60-b259-46c9-9033-d3ce3d13ef74",
   "metadata": {
    "origin_pos": 10,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.nn.parameter.Parameter'>\n",
      "Parameter containing:\n",
      "tensor([-0.0579], requires_grad=True)\n",
      "tensor([-0.0579])\n"
     ]
    }
   ],
   "source": [
    "print(type(net[2].bias))\n",
    "print(net[2].bias)\n",
    "print(net[2].bias.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e6cc0743-663b-4452-8b65-53a6652e44f1",
   "metadata": {
    "origin_pos": 14,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[2].weight.grad == None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35a83ecb-6df8-4cb6-a0ce-29f521b35d6e",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### 一次性访问所有参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cdaf87e6-a880-4f9e-94c3-e7dc48e48878",
   "metadata": {
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n",
      "('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n"
     ]
    }
   ],
   "source": [
    "print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n",
    "print(*[(name, param.shape) for name, param in net.named_parameters()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a826965f-f15b-4fb0-8711-f523cde61fb5",
   "metadata": {
    "origin_pos": 21,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0579])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.state_dict()['2.bias'].data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71f979b6-86b5-4fdf-8364-1281369f188c",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### 从嵌套块收集参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6cd58e6f-6e87-4b9d-9286-dba5383e035a",
   "metadata": {
    "origin_pos": 25,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3983],\n",
       "        [0.3983]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def block1():\n",
    "    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),\n",
    "                         nn.Linear(8, 4), nn.ReLU())\n",
    "\n",
    "def block2():\n",
    "    net = nn.Sequential()\n",
    "    for i in range(4):\n",
    "        # 在这里嵌套\n",
    "        net.add_module(f'block {i}', block1())\n",
    "    return net\n",
    "\n",
    "rgnet = nn.Sequential(block2(), nn.Linear(4, 1))\n",
    "rgnet(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2d9b23f-428a-4428-b2fc-a2c58e7c27db",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "设计了网络后，我们看看它是如何工作的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c77d9236-9e3c-4eba-9cac-cd929ceb538f",
   "metadata": {
    "origin_pos": 29,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Sequential(\n",
      "    (block 0): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 1): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 2): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 3): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (1): Linear(in_features=4, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(rgnet)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8af6c129-6b29-4bf0-afdf-93b71f68cf43",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "通过嵌套列表索引访问：下面访问第一个主要的块中、第二个子块的第一层的偏置项"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "05b273c1-1f6c-4d6d-b2df-2d204b6a60b6",
   "metadata": {
    "origin_pos": 33,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.3155,  0.0512,  0.3313,  0.2001, -0.2423, -0.2325, -0.4816, -0.0248])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rgnet[0][1][0].bias.data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3592ddb-97a7-4979-a94f-076a70dc10a5",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### 内置初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f41e080d-de2f-42ef-9648-0fe114e949e7",
   "metadata": {
    "origin_pos": 41,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 0.0091,  0.0023, -0.0125,  0.0040]), tensor(0.))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_normal(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.normal_(m.weight, mean=0, std=0.01)\n",
    "        nn.init.zeros_(m.bias)\n",
    "net.apply(init_normal)\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ebea5b23-9c84-40c2-a3d7-c1f2b4e77de3",
   "metadata": {
    "origin_pos": 45,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1., 1., 1., 1.]), tensor(0.))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_constant(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.constant_(m.weight, 1)\n",
    "        nn.init.zeros_(m.bias)\n",
    "net.apply(init_constant)\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0cc312b-e55c-4945-9673-065812c7679e",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "对某些块应用不同的初始化方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "dae9a89c-623f-4fd5-b61d-95b7d63d4f69",
   "metadata": {
    "origin_pos": 49,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.5587,  0.2209, -0.0744,  0.3801])\n",
      "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
     ]
    }
   ],
   "source": [
    "def xavier(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.xavier_uniform_(m.weight)\n",
    "def init_42(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.constant_(m.weight, 42)\n",
    "\n",
    "net[0].apply(xavier)\n",
    "net[2].apply(init_42)\n",
    "print(net[0].weight.data[0])\n",
    "print(net[2].weight.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "940a6d73-2b27-44ac-a7b5-c24a305e7371",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### 自定义初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2bacf98d-96a0-4d1f-abbc-dd2b9677ce89",
   "metadata": {
    "origin_pos": 56,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init weight torch.Size([8, 4])\n",
      "Init weight torch.Size([1, 8])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-8.4986,  0.0000, -0.0000, -0.0000],\n",
       "        [ 6.7884,  0.0000, -9.8570, -6.8247]], grad_fn=<SliceBackward>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def my_init(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        print(\"Init\", *[(name, param.shape)\n",
    "                        for name, param in m.named_parameters()][0])\n",
    "        nn.init.uniform_(m.weight, -10, 10)\n",
    "        m.weight.data *= m.weight.data.abs() >= 5 # |w| < 5 时清零\n",
    "\n",
    "net.apply(my_init)\n",
    "net[0].weight[:2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b28ccf9d-391d-4f47-8865-e07d4fcfb32c",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "直接设置参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a5131518-5f49-4b84-8e52-b0c9f109615b",
   "metadata": {
    "origin_pos": 60,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([42.,  1.,  1.,  1.])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.data[:] += 1\n",
    "net[0].weight.data[0, 0] = 42\n",
    "net[0].weight.data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2c801fd-1770-4440-9870-b48892f23ef1",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### 参数绑定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "fe7ce342-6ef6-40b5-aef3-2d4fcfc69cdf",
   "metadata": {
    "origin_pos": 65,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([True, True, True, True, True, True, True, True])\n",
      "tensor([True, True, True, True, True, True, True, True])\n"
     ]
    }
   ],
   "source": [
    "# 我们需要给共享层一个名称，以便可以引用它的参数\n",
    "shared = nn.Linear(8, 8)\n",
    "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),\n",
    "                    shared, nn.ReLU(),\n",
    "                    shared, nn.ReLU(),\n",
    "                    nn.Linear(8, 1))\n",
    "net(X)\n",
    "# 检查参数是否相同\n",
    "print(net[2].weight.data[0] == net[4].weight.data[0])\n",
    "net[2].weight.data[0, 0] = 100\n",
    "# 确保它们实际上是同一个对象，而不只是有相同的值\n",
    "print(net[2].weight.data[0] == net[4].weight.data[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33d935ce-0cf8-4208-bd13-b5178446bdbc",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "思考：当参数绑定时，梯度会发生什么情况？"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82fa16e0-4f58-49dd-b9c4-4d06c87b58e6",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "答案：由于这两层共享同一组参数，在反向传播期间第二个隐藏层（即第三个神经网络层）和第三个隐藏层（即第五个神经网络层）的梯度会加在一起，累加到同一个参数的梯度上。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91739986-4217-4871-831c-84dc63320b06",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "slideshow": {
     "slide_type": "-"
    },
    "tags": []
   },
   "source": [
    "## 自定义层\n",
    "\n",
    "构造一个没有任何参数的自定义层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f0f9d92-6b58-416e-a7bd-5a4b9a42f056",
   "metadata": {
    "origin_pos": 6,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class CenteredLayer(nn.Module):\n",
    "    \"\"\"Parameter-free layer that subtracts the input's mean from the input.\"\"\"\n",
    "\n",
    "    def forward(self, X):\n",
    "        # X.mean() reduces over all elements, so the result has zero overall mean.\n",
    "        centered = X - X.mean()\n",
    "        return centered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "db87b3aa-5a0c-4643-b3f4-3d83fbaf9814",
   "metadata": {
    "origin_pos": 6,
    "slideshow": {
     "slide_type": "slide"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = CenteredLayer()\n",
    "layer(torch.FloatTensor([1, 2, 3, 4, 5]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ca64c2a-707b-4e90-a7ec-f2ea135206f2",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "将层作为组件合并到更复杂的模型中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3a23c6a9-fea6-412f-a33c-d7dc4c27f68e",
   "metadata": {
    "origin_pos": 14,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.6776e-09, grad_fn=<MeanBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())\n",
    "\n",
    "# 向该网络发送随机数据后，检查均值是否为0\n",
    "Y = net(torch.rand(4, 8))\n",
    "Y.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51ed2384-7439-4321-af2c-a012f3b9fef6",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "带参数的层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e7657660-f475-410b-8ce2-56dbff3b2344",
   "metadata": {
    "origin_pos": 23,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-0.5454,  1.2766, -0.3547],\n",
       "        [-0.4969, -0.2906, -0.9240],\n",
       "        [ 0.2956, -0.8858,  1.3960],\n",
       "        [-0.3093,  1.2917,  1.4760],\n",
       "        [ 0.3728,  1.4528,  0.7151]], requires_grad=True)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MyLinear(nn.Module):\n",
    "    \"\"\"Fully connected layer with a built-in ReLU, written from scratch.\n",
    "\n",
    "    Args:\n",
    "        in_units: size of the last dimension of the input.\n",
    "        units: number of output features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_units, units):\n",
    "        super().__init__()\n",
    "        # Both tensors are drawn from a standard normal distribution.\n",
    "        self.weight = nn.Parameter(torch.randn(in_units, units))\n",
    "        self.bias = nn.Parameter(torch.randn(units,))\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Return ``relu(X @ weight + bias)``.\"\"\"\n",
    "        # Use the Parameters themselves rather than ``.data``: ``.data`` is\n",
    "        # detached from autograd, so no gradient would reach weight or bias.\n",
    "        linear = torch.matmul(X, self.weight) + self.bias\n",
    "        return F.relu(linear)\n",
    "\n",
    "linear = MyLinear(5, 3)\n",
    "linear.weight"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebb30df8-f029-4fb1-a92f-7ccaa331c1ce",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "使用自定义层直接执行前向传播计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9e4b657e-d9a2-43eb-8abd-8c5529561b85",
   "metadata": {
    "origin_pos": 27,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 1.4069, 0.0000],\n",
       "        [0.0000, 1.1079, 0.0000]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "linear(torch.rand(2, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0346feaf-d2d8-4d7d-8bd7-d82461a184d3",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "使用自定义层构建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4680cfdf-6cf1-48d0-b835-9126fda45ba7",
   "metadata": {
    "origin_pos": 31,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10.2999],\n",
       "        [ 8.0606]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))\n",
    "net(torch.rand(2, 64))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94facba3-5fd3-4880-9dba-feb7c5d34621",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "slideshow": {
     "slide_type": "-"
    },
    "tags": []
   },
   "source": [
    "## 读写文件\n",
    "\n",
    "加载和保存张量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1893269f-ee8e-4afd-ae7f-a67d4ce2d264",
   "metadata": {
    "origin_pos": 6,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "x = torch.arange(4)\n",
    "torch.save(x, 'x-file')\n",
    "\n",
    "x2 = torch.load('x-file')\n",
    "x2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4dc5887-89df-4560-97a0-da93e373dc27",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "存储一个张量列表，然后把它们读回内存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7cc2777c-fe64-4a36-9fc3-4056514c6485",
   "metadata": {
    "origin_pos": 10,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = torch.zeros(4)\n",
    "torch.save([x, y],'x-files')\n",
    "x2, y2 = torch.load('x-files')\n",
    "(x2, y2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46d5e258-b229-49b1-a193-a40082724aa1",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "写入或读取从字符串映射到张量的字典"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e1f88ab2-85f8-45ad-8cff-13a33deca4f4",
   "metadata": {
    "origin_pos": 14,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mydict = {'x': x, 'y': y}\n",
    "torch.save(mydict, 'mydict')\n",
    "mydict2 = torch.load('mydict')\n",
    "mydict2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22107304-b2eb-4160-b212-31634da00451",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "加载和保存模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "25c76445-520d-4635-849e-bf28f8ecec56",
   "metadata": {
    "origin_pos": 18,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Multilayer perceptron: Linear(20, 256) -> ReLU -> Linear(256, 10).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(20, 256)\n",
    "        self.output = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the hidden layer, a ReLU, then the output layer.\"\"\"\n",
    "        activations = F.relu(self.hidden(x))\n",
    "        return self.output(activations)\n",
    "\n",
    "net = MLP()\n",
    "X = torch.randn(size=(2, 20))\n",
    "Y = net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecbd326b-9f20-4497-81d1-cfd51a6b2c1f",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "将模型的参数存储在一个叫做“mlp.params”的文件中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3cdb1197-4a7b-48ba-9ad7-ce6f3b913354",
   "metadata": {
    "origin_pos": 22,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "torch.save(net.state_dict(), 'mlp.params')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c7b6e86-20c7-4d13-b4e7-8159f38dcbfc",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "实例化了原始多层感知机模型的一个备份。\n",
    "直接读取文件中存储的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3dd363e9-1823-4ef5-ab97-201af72c359d",
   "metadata": {
    "origin_pos": 26,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLP(\n",
       "  (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
       "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clone = MLP()\n",
    "clone.load_state_dict(torch.load('mlp.params'))\n",
    "clone.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f8b8af86-8667-413f-91c2-50dbc05680cf",
   "metadata": {
    "origin_pos": 30,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True, True, True, True, True, True, True],\n",
       "        [True, True, True, True, True, True, True, True, True, True]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_clone = clone(X)\n",
    "Y_clone == Y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fecfb83-9b58-4bcd-97ee-10385bd85371",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "slideshow": {
     "slide_type": "-"
    },
    "tags": []
   },
   "source": [
    "## GPU\n",
    "\n",
    "查看显卡信息"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "346f4a85-2a82-4ad4-a3f6-1bae6807dbe2",
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sat Mar 12 20:02:56 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 510.47.03    Driver Version: 511.65       CUDA Version: 11.6     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0  On |                  N/A |\n",
      "| N/A   38C    P5    19W /  N/A |   1395MiB /  6144MiB |     14%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|    0   N/A  N/A       112      G   /Xwayland                       N/A      |\n",
      "|    0   N/A  N/A      9974      G   /msedge                         N/A      |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0c59f2e-bb6a-4992-b3b3-1d6a7c1797b3",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "计算设备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9f1e7222-c4ac-4b74-94f9-416a49003441",
   "metadata": {
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(device(type='cpu'), device(type='cuda'), device(type='cuda', index=1))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.device('cpu'), torch.device('cuda'), torch.device('cuda:1')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fdb7f48-b679-42d9-9a39-849efea5f05f",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "查询可用gpu的数量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "121825ab-a77b-4497-8d1a-cba441765348",
   "metadata": {
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cuda.device_count()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b55283c2-0276-41f5-8592-3ad08d1bda01",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "下面这两个函数允许我们在所请求的GPU不存在的情况下仍然运行代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "820280b0-fb14-4e5e-93e2-df014ad57fc8",
   "metadata": {
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(device(type='cuda', index=0),\n",
       " device(type='cpu'),\n",
       " [device(type='cuda', index=0), device(type='cuda', index=1)])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def try_gpu(i=0):\n",
    "    \"\"\"Return ``cuda:i`` if that GPU exists, otherwise the CPU device.\n",
    "\n",
    "    Args:\n",
    "        i: index of the requested GPU.\n",
    "\n",
    "    Returns:\n",
    "        ``torch.device(f'cuda:{i}')`` when ``0 <= i < torch.cuda.device_count()``,\n",
    "        else ``torch.device('cpu')``.\n",
    "    \"\"\"\n",
    "    # The lower bound matters: a negative index must fall back to the CPU\n",
    "    # rather than being formatted into an invalid device string.\n",
    "    if 0 <= i < torch.cuda.device_count():\n",
    "        return torch.device(f'cuda:{i}')\n",
    "    return torch.device('cpu')\n",
    "\n",
    "def try_all_gpus():\n",
    "    \"\"\"Return every available GPU device, or ``[cpu]`` when there is none.\"\"\"\n",
    "    gpu_count = torch.cuda.device_count()\n",
    "    if gpu_count == 0:\n",
    "        return [torch.device('cpu')]\n",
    "    return [torch.device(f'cuda:{idx}') for idx in range(gpu_count)]\n",
    "\n",
    "try_gpu(), try_gpu(10), try_all_gpus()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "420a08ee-e894-46b8-9534-3fb60230ba34",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "查询张量所在的设备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0ee876a1-a7f8-4217-97ee-4813c356a004",
   "metadata": {
    "origin_pos": 20,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([1, 2, 3])\n",
    "x.device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8523060-40ff-47ff-9bd4-9631a5f8f008",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "存储在GPU上"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c9af569d-d446-4976-a04c-599b8c3bd600",
   "metadata": {
    "origin_pos": 24,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 1.],\n",
       "        [1., 1., 1.]], device='cuda:0')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.ones(2, 3, device=try_gpu())\n",
    "X"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a3b5790-6ecf-4e9a-bbae-d53a8b61cad6",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "在第二个GPU上创建一个随机张量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9c96e08b-110f-4bc9-881d-a97faa5da05c",
   "metadata": {
    "origin_pos": 28,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.5473, 0.1942, 0.2213],\n",
       "        [0.5998, 0.5565, 0.0372]], device='cuda:1')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y = torch.rand(2, 3, device=try_gpu(1))\n",
    "Y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5479fda-7ffb-4817-9565-c015cb52aaf3",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "要计算`X + Y`，我们需要决定在哪里执行这个操作"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "98bb6090-d2c0-431b-8872-91310ab36b32",
   "metadata": {
    "origin_pos": 32,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1.],\n",
      "        [1., 1., 1.]], device='cuda:0')\n",
      "tensor([[1., 1., 1.],\n",
      "        [1., 1., 1.]], device='cuda:1')\n"
     ]
    }
   ],
   "source": [
    "Z = X.cuda(1)\n",
    "print(X)\n",
    "print(Z)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d999064a-6aa9-441f-8fe9-9ee7c2eb4289",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "现在数据在同一个GPU上（`Z`和`Y`都在），我们可以将它们相加"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "646839ab-f817-488c-8239-1ef6599d22ee",
   "metadata": {
    "origin_pos": 35,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.5473, 1.1942, 1.2213],\n",
       "        [1.5998, 1.5565, 1.0372]], device='cuda:1')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y + Z"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccfca3e4-4e0b-479e-b824-bd2cec2c3dc7",
   "metadata": {},
   "source": [
    "变量Z已经在第二个GPU上，如果还是调用Z.cuda(1)将返回Z，而不会复制并分配新内存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "096c1468-9770-46e7-a179-5d6158f06c13",
   "metadata": {
    "origin_pos": 40,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Z.cuda(1) is Z"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8bbb16a-514c-4e1a-b689-3571073feb87",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "神经网络与GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "eda84331-029f-4e68-aaee-63245054052a",
   "metadata": {
    "origin_pos": 47,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.2194],\n",
       "        [1.2194]], device='cuda:0', grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = nn.Sequential(nn.Linear(3, 1))\n",
    "net = net.to(device=try_gpu())\n",
    "\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c58b8e18-8d7b-4f37-874f-1441e329c94d",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "确认模型参数存储在同一个GPU上"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e404204c-5e64-4b38-8900-c7b2f5c2ceb9",
   "metadata": {
    "origin_pos": 50,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.data.device"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
